{% extends 'base.exported.class' %}

{% block content %}
import com.google.gson.Gson;

import java.io.File;
import java.io.FileNotFoundException;
import java.util.Locale;
import java.util.Scanner;
{% if is_test or to_json %}
import java.util.Arrays;


{% endif %}
class {{ class_name }} {

    /**
     * Activation functions that can be applied to a layer's values.
     * Constant names are matched against the upper-cased activation names
     * read from the model data.
     */
    private enum Activation {
        /** Values are passed through unchanged. */
        IDENTITY,
        /** Logistic sigmoid: 1 / (1 + e^-x). */
        LOGISTIC,
        /** Rectified linear unit: max(0, x). */
        RELU,
        /** Hyperbolic tangent. */
        TANH,
        /** Normalised exponentials over the whole layer. */
        SOFTMAX
    }

    private class Classifier {
        private String hidden_activation;
        private Activation hidden;
        private String output_activation;
        private Activation output;
        private double[][] network;
        private double[][][] weights;
        private double[][] bias;
        private int[] layers;
    }

    /** Model data and per-layer work buffers; assigned in the constructor. */
    private Classifier clf;

    /**
     * Loads the exported model data from a JSON file and prepares the
     * per-layer buffers used during prediction.
     *
     * @param file path to the JSON file containing the model data
     * @throws FileNotFoundException if {@code file} cannot be opened
     */
    public {{ class_name }}(String file) throws FileNotFoundException {
        String jsonStr;
        // try-with-resources so the underlying file handle is always released.
        try (Scanner scanner = new Scanner(new File(file), "UTF-8")) {
            // "\\Z" = end of input, i.e. read the whole file as a single token.
            jsonStr = scanner.useDelimiter("\\Z").next();
        }
        this.clf = new Gson().fromJson(jsonStr, Classifier.class);
        // Slot 0 is reserved for the input vector and is filled on each prediction.
        this.clf.network = new double[this.clf.layers.length + 1][];
        for (int i = 0, l = this.clf.layers.length; i < l; i++) {
            this.clf.network[i + 1] = new double[this.clf.layers[i]];
        }
        // Locale.ROOT keeps the upper-casing independent of the default locale
        // (e.g. Turkish would turn "identity" into "İDENTİTY" and break valueOf).
        this.clf.hidden = Activation.valueOf(this.clf.hidden_activation.toUpperCase(Locale.ROOT));
        this.clf.output = Activation.valueOf(this.clf.output_activation.toUpperCase(Locale.ROOT));
    }

    /**
     * Returns the index of the largest value; the first one wins on ties.
     *
     * @param nums values to scan
     * @return index of the maximum, or 0 when {@code nums} is empty
     */
    private int findMax(double[] nums) {
        int best = 0;
        // Position 0 is the initial candidate, so scanning can start at 1.
        for (int pos = 1; pos < nums.length; pos++) {
            if (nums[pos] > nums[best]) {
                best = pos;
            }
        }
        return best;
    }

    /**
     * Applies an activation function to a layer in place.
     *
     * @param activation the function to apply; {@code IDENTITY} leaves the values untouched
     * @param v          the layer values, overwritten with the activated values
     * @return the same array instance {@code v}
     */
    private double[] compute(Activation activation, double[] v) {
        final int size = v.length;
        switch (activation) {
            case TANH:
                for (int k = 0; k < size; k++) {
                    v[k] = Math.tanh(v[k]);
                }
                return v;
            case RELU:
                for (int k = 0; k < size; k++) {
                    v[k] = Math.max(0, v[k]);
                }
                return v;
            case LOGISTIC:
                for (int k = 0; k < size; k++) {
                    v[k] = 1. / (1. + Math.exp(-v[k]));
                }
                return v;
            case SOFTMAX:
                // Subtract the largest value before exponentiating so that
                // Math.exp cannot overflow.
                double peak = Double.NEGATIVE_INFINITY;
                for (int k = 0; k < size; k++) {
                    if (v[k] > peak) {
                        peak = v[k];
                    }
                }
                double total = 0.;
                for (int k = 0; k < size; k++) {
                    v[k] = Math.exp(v[k] - peak);
                    total += v[k];
                }
                for (int k = 0; k < size; k++) {
                    v[k] /= total;
                }
                return v;
            default:
                // IDENTITY: nothing to do.
                return v;
        }
    }

    /**
     * Zeroes the buffers of every layer strictly between the input layer
     * (index 0) and the output layer (last index).
     */
    private void resetNetwork() {
        final double[][] buffers = this.clf.network;
        final int outputIndex = buffers.length - 1;
        for (int layer = 1; layer < outputIndex; layer++) {
            final double[] values = buffers[layer];
            for (int unit = 0; unit < values.length; unit++) {
                values[unit] = 0;
            }
        }
    }

    /**
     * Propagates an input vector through the network.
     *
     * <p>Each layer is computed as bias plus the weighted sum of the previous
     * layer. The hidden activation is applied to every layer except the last,
     * which gets the output activation.
     *
     * @param neurons the input features; the array is referenced, not copied
     * @return a fresh copy of the activated output layer, so that results
     *         handed to callers are not overwritten by later predictions
     */
    private double[] feedForward(double[] neurons) {
        final double[][] network = this.clf.network;
        final int last = network.length - 1;
        network[0] = neurons;
        for (int i = 0; i < last; i++) {
            final double[] source = network[i];
            final double[] target = network[i + 1];
            for (int j = 0; j < target.length; j++) {
                // Start from the bias, then accumulate the weighted inputs.
                target[j] = this.clf.bias[i][j];
                for (int l = 0; l < source.length; l++) {
                    target[j] += source[l] * this.clf.weights[i][l][j];
                }
            }
            // Only hidden layers get the hidden activation; the output layer
            // is handled after the loop.
            if ((i + 1) < last) {
                network[i + 1] = this.compute(this.clf.hidden, target);
            }
        }
        network[last] = this.compute(this.clf.output, network[last]);
        this.resetNetwork();
        // The output buffer is reused on every call; return a copy instead of
        // leaking the internal array.
        return network[last].clone();
    }

    /**
     * Predicts the class index for the given features.
     *
     * @param neurons the input features
     * @return for a single output unit, 1 if its value exceeds 0.5 and 0
     *         otherwise; for several output units, the index of the largest
     */
    public int predict(double[] neurons) {
        final double[] scores = this.feedForward(neurons);
        if (scores.length != 1) {
            return findMax(scores);
        }
        return scores[0] > .5 ? 1 : 0;
    }

    /**
     * Returns the class probabilities for the given features, where index
     * {@code i} holds the probability of class {@code i}.
     *
     * @param neurons the input features
     * @return one probability per class; for a network with a single output
     *         unit the result is {@code {1 - p, p}}, where {@code p} is that
     *         unit's value
     */
    public double[] predictProba(double[] neurons) {
        double[] lastLayer = this.feedForward(neurons);
        if (lastLayer.length == 1) {
            // predict() returns class 1 when the single output exceeds 0.5, so
            // that output is P(class 1) and belongs at index 1, not index 0.
            return new double[] {1 - lastLayer[0], lastLayer[0]};
        }
        return lastLayer;
    }

    /**
     * Command-line entry point: expects the path to the exported model data
     * (a {@code .json} file) followed by one numeric argument per feature,
     * then prints the prediction and the class probabilities.
     *
     * @param args model path followed by the feature values
     * @throws FileNotFoundException if the model file cannot be opened
     */
    public static void main(String[] args) throws FileNotFoundException {
        final int nFeatures = {{ n_features }};
        final boolean validArgs = args.length == (nFeatures + 1) && args[0].endsWith(".json");
        if (!validArgs) {
            throw new IllegalArgumentException("You have to pass the path to the exported model data and " + nFeatures + " features.");
        }

        // Features: every argument after the model path.
        final double[] features = new double[args.length - 1];
        for (int k = 0; k < features.length; k++) {
            features[k] = Double.parseDouble(args[k + 1]);
        }

        // Model data:
        final String modelData = args[0];

        // Estimator:
        final {{ class_name }} clf = new {{ class_name }}(modelData);

        {% if is_test or to_json %}
        // Emit prediction and probabilities as one JSON object:
        final int prediction = clf.predict(features);
        final double[] probabilities = clf.predictProba(features);
        System.out.println("{\"predict\": " + prediction + ", \"predict_proba\": " + Arrays.toString(probabilities) + "}");
        {% else %}
        // Class prediction:
        final int prediction = clf.predict(features);
        System.out.println("Predicted class: #" + prediction);

        // Class probabilities, comma-separated, without a trailing newline:
        final double[] probabilities = clf.predictProba(features);
        final StringBuilder joined = new StringBuilder();
        for (int k = 0; k < probabilities.length; k++) {
            if (k > 0) {
                joined.append(",");
            }
            joined.append(probabilities[k]);
        }
        System.out.print(joined.toString());
        {% endif %}

    }
}
{% endblock %}